{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e0336840",
   "metadata": {},
   "source": [
    "# Introduction to Taxi ETL Job\n",
    "This is the Taxi ETL job to generate the input datasets for the Taxi XGBoost job."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86fd8ad9",
   "metadata": {},
   "source": [
    "## Prerequirement\n",
    "### 1. Download data\n",
    "All data could be found at https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page\n",
    "\n",
    "### 2. Download needed jar\n",
    "* [rapids-4-spark_2.12-25.04.0.jar](https://repo1.maven.org/maven2/com/nvidia/rapids-4-spark_2.12/25.04.0/rapids-4-spark_2.12-25.04.0.jar)\n",
    "\n",
    "### 3. Start Spark Standalone\n",
    "Before running the script, please setup Spark standalone mode\n",
    "\n",
    "### 4. Add ENV\n",
    "```\n",
    "$ export SPARK_JARS=rapids-4-spark_2.12-25.04.0.jar\n",
    "\n",
    "```\n",
    "\n",
    "### 5.Start Jupyter Notebook with spylon-kernel or toree\n",
    "\n",
    "```\n",
    "$ jupyter notebook --allow-root --notebook-dir=${your-dir} --config=${your-configs}\n",
    "```\n",
    "\n",
    "## Import Libs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1e50cfad",
   "metadata": {},
   "outputs": [],
   "source": [
    "import org.apache.spark.sql.SparkSession\n",
    "import org.apache.spark.sql.DataFrame\n",
    "import org.apache.spark.sql.functions._\n",
    "import org.apache.spark.sql.types.DataTypes.{DoubleType, IntegerType, StringType}\n",
    "import org.apache.spark.sql.types.{FloatType, StructField, StructType}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24f69140",
   "metadata": {},
   "source": [
    "## Script Settings\n",
    "\n",
    "### 1. File Path Settings\n",
    "* Define input file path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "317b9415",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "lastException = null\n",
       "dataRoot = /data\n",
       "rawPath = /data/taxi/taxi-etl-input-small.csv\n",
       "outPath = /data/datasets/taxi/output\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "/data/taxi/output"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val dataRoot = sys.env.getOrElse(\"DATA_ROOT\", \"/data\")\n",
    "val rawPath = dataRoot + \"/taxi/taxi-etl-input-small.csv\"\n",
    "val outPath = dataRoot + \"/taxi/output\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6f036d30",
   "metadata": {},
   "source": [
    "## Function and Object Define\n",
    "### Define the constants\n",
    "\n",
    "* Define input file schema"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "acc23ac1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "rawSchema = StructType(StructField(vendor_id,StringType,true), StructField(pickup_datetime,StringType,true), StructField(dropoff_datetime,StringType,true), StructField(passenger_count,IntegerType,true), StructField(trip_distance,DoubleType,true), StructField(pickup_longitude,DoubleType,true), StructField(pickup_latitude,DoubleType,true), StructField(rate_code,StringType,true), StructField(store_and_fwd_flag,StringType,true), StructField(dropoff_longitude,DoubleType,true), StructField(dropoff_latitude,DoubleType,true), StructField(payment_type,StringType,true), StructField(fare_amount,DoubleType,true), StructField(surcharge,DoubleType,true), StructField(mta_tax,DoubleType,true), StructField(tip_amount,DoubleType,true), StructField(tolls_amount,Doubl...\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "StructType(StructField(vendor_id,StringType,true), StructField(pickup_datetime,StringType,true), StructField(dropoff_datetime,StringType,true), StructField(passenger_count,IntegerType,true), StructField(trip_distance,DoubleType,true), StructField(pickup_longitude,DoubleType,true), StructField(pickup_latitude,DoubleType,true), StructField(rate_code,StringType,true), StructField(store_and_fwd_flag,StringType,true), StructField(dropoff_longitude,DoubleType,true), StructField(dropoff_latitude,DoubleType,true), StructField(payment_type,StringType,true), StructField(fare_amount,DoubleType,true), StructField(surcharge,DoubleType,true), StructField(mta_tax,DoubleType,true), StructField(tip_amount,DoubleType,true), StructField(tolls_amount,Doubl..."
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val rawSchema = StructType(Seq(\n",
    "    StructField(\"vendor_id\", StringType),\n",
    "    StructField(\"pickup_datetime\", StringType),\n",
    "    StructField(\"dropoff_datetime\", StringType),\n",
    "    StructField(\"passenger_count\", IntegerType),\n",
    "    StructField(\"trip_distance\", DoubleType),\n",
    "    StructField(\"pickup_longitude\", DoubleType),\n",
    "    StructField(\"pickup_latitude\", DoubleType),\n",
    "    StructField(\"rate_code\", StringType),\n",
    "    StructField(\"store_and_fwd_flag\", StringType),\n",
    "    StructField(\"dropoff_longitude\", DoubleType),\n",
    "    StructField(\"dropoff_latitude\", DoubleType),\n",
    "    StructField(\"payment_type\", StringType),\n",
    "    StructField(\"fare_amount\", DoubleType),\n",
    "    StructField(\"surcharge\", DoubleType),\n",
    "    StructField(\"mta_tax\", DoubleType),\n",
    "    StructField(\"tip_amount\", DoubleType),\n",
    "    StructField(\"tolls_amount\", DoubleType),\n",
    "    StructField(\"total_amount\", DoubleType)\n",
    "  ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2e467519",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "trainRatio = 80\n",
       "evalRatio = 20\n",
       "trainEvalRatio = 0\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "dataRatios: (Int, Int, Int)\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def dataRatios: (Int, Int, Int) = {\n",
    "    val ratios = (80, 20)\n",
    "    (ratios._1, ratios._2, 100 - ratios._1 - ratios._2)\n",
    "  }\n",
    "val (trainRatio, evalRatio, trainEvalRatio) = dataRatios"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c2024d7",
   "metadata": {},
   "source": [
    "* Build the spark session and dataframe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b551ca1d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "sparkSession = org.apache.spark.sql.SparkSession@68530eb7\n",
       "df = [vendor_id: string, pickup_datetime: string ... 16 more fields]\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "[vendor_id: string, pickup_datetime: string ... 16 more fields]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "// Build the spark session and data reader as usual\n",
    "val sparkSession = SparkSession.builder.appName(\"taxi-etl\").getOrCreate\n",
    "val df = sparkSession.read.option(\"header\", true).schema(rawSchema).csv(rawPath)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f50ff7d",
   "metadata": {},
   "source": [
    "* Define some ETL functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3ca5738f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dropUseless: (dataFrame: org.apache.spark.sql.DataFrame)org.apache.spark.sql.DataFrame\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def dropUseless(dataFrame: DataFrame): DataFrame = {\n",
    "    dataFrame.drop(\n",
    "      \"dropoff_datetime\",\n",
    "      \"payment_type\",\n",
    "      \"surcharge\",\n",
    "      \"mta_tax\",\n",
    "      \"tip_amount\",\n",
    "      \"tolls_amount\",\n",
    "      \"total_amount\")\n",
    "  }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "852b06c3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "encodeCategories: (dataFrame: org.apache.spark.sql.DataFrame)org.apache.spark.sql.DataFrame\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def encodeCategories(dataFrame: DataFrame): DataFrame = {\n",
    "    val categories = Seq(\"vendor_id\", \"rate_code\", \"store_and_fwd_flag\")\n",
    "\n",
    "    (categories.foldLeft(dataFrame) {\n",
    "      case (df, category) => df.withColumn(category, hash(col(category)))\n",
    "    }).withColumnRenamed(\"store_and_fwd_flag\", \"store_and_fwd\")\n",
    "  }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "dbf0ab75",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "fillNa: (dataFrame: org.apache.spark.sql.DataFrame)org.apache.spark.sql.DataFrame\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def fillNa(dataFrame: DataFrame): DataFrame = {\n",
    "    dataFrame.na.fill(-1)\n",
    "  }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "39308a05",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "removeInvalid: (dataFrame: org.apache.spark.sql.DataFrame)org.apache.spark.sql.DataFrame\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def removeInvalid(dataFrame: DataFrame): DataFrame = {\n",
    "    val conditions = Seq(\n",
    "      Seq(\"fare_amount\", 0, 500),\n",
    "      Seq(\"passenger_count\", 0, 6),\n",
    "      Seq(\"pickup_longitude\", -75, -73),\n",
    "      Seq(\"dropoff_longitude\", -75, -73),\n",
    "      Seq(\"pickup_latitude\", 40, 42),\n",
    "      Seq(\"dropoff_latitude\", 40, 42))\n",
    "\n",
    "    conditions\n",
    "      .map { case Seq(column, min, max) => \"%s > %d and %s < %d\".format(column, min, column, max) }\n",
    "      .foldLeft(dataFrame) {\n",
    "        _.filter(_)\n",
    "      }\n",
    "  }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "11cd052b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "convertDatetime: (dataFrame: org.apache.spark.sql.DataFrame)org.apache.spark.sql.DataFrame\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def convertDatetime(dataFrame: DataFrame): DataFrame = {\n",
    "    val datetime = col(\"pickup_datetime\")\n",
    "    dataFrame\n",
    "      .withColumn(\"pickup_datetime\", to_timestamp(datetime))\n",
    "      .withColumn(\"year\", year(datetime))\n",
    "      .withColumn(\"month\", month(datetime))\n",
    "      .withColumn(\"day\", dayofmonth(datetime))\n",
    "      .withColumn(\"day_of_week\", dayofweek(datetime))\n",
    "      .withColumn(\n",
    "        \"is_weekend\",\n",
    "        col(\"day_of_week\").isin(1, 7).cast(IntegerType)) // 1: Sunday, 7: Saturday\n",
    "      .withColumn(\"hour\", hour(datetime))\n",
    "      .drop(\"pickup_datetime\")\n",
    "  }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "71e1b568",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "addHDistance: (dataFrame: org.apache.spark.sql.DataFrame)org.apache.spark.sql.DataFrame\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def addHDistance(dataFrame: DataFrame): DataFrame = {\n",
    "    val P = math.Pi / 180\n",
    "    val lat1 = col(\"pickup_latitude\")\n",
    "    val lon1 = col(\"pickup_longitude\")\n",
    "    val lat2 = col(\"dropoff_latitude\")\n",
    "    val lon2 = col(\"dropoff_longitude\")\n",
    "    val internalValue = (lit(0.5)\n",
    "      - cos((lat2 - lat1) * P) / 2\n",
    "      + cos(lat1 * P) * cos(lat2 * P) * (lit(1) - cos((lon2 - lon1) * P)) / 2)\n",
    "    val hDistance = lit(12734) * asin(sqrt(internalValue))\n",
    "    dataFrame.withColumn(\"h_distance\", hDistance)\n",
    "  }"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fe805d5",
   "metadata": {},
   "source": [
    "* Define main ETL function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "6da3b832",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "preProcess: (dataFrame: org.apache.spark.sql.DataFrame, splits: Array[Int])Array[org.apache.spark.sql.DataFrame]\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def preProcess(dataFrame: DataFrame, splits: Array[Int]): Array[DataFrame] = {\n",
    "    val processes = Seq[DataFrame => DataFrame](\n",
    "      dropUseless,\n",
    "      encodeCategories,\n",
    "      fillNa,\n",
    "      removeInvalid,\n",
    "      convertDatetime,\n",
    "      addHDistance\n",
    "    )\n",
    "\n",
    "    processes\n",
    "      .foldLeft(dataFrame) { case (df, process) => process(df) }\n",
    "      .randomSplit(splits.map(_.toDouble))\n",
    "  }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "85541b03",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dataset = Array([vendor_id: int, passenger_count: int ... 15 more fields], [vendor_id: int, passenger_count: int ... 15 more fields], [vendor_id: int, passenger_count: int ... 15 more fields])\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "Array([vendor_id: int, passenger_count: int ... 15 more fields], [vendor_id: int, passenger_count: int ... 15 more fields], [vendor_id: int, passenger_count: int ... 15 more fields])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val dataset = preProcess(df, Array(trainRatio, trainEvalRatio, evalRatio))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6787cac7",
   "metadata": {},
   "source": [
    "## Run ETL Process and Save the Result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "371886e8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Elapsed time : 4.371s\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "t0 = 1654139600797\n",
       "t1 = 1654139605168\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "1654139605168"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val t0 = System.currentTimeMillis\n",
    "for ((name, index) <- Seq(\"train\", \"eval\", \"trans\").zipWithIndex) {\n",
    "        dataset(index).write.mode(\"overwrite\").parquet(outPath + \"/parquet/\" + name)\n",
    "        dataset(index).write.mode(\"overwrite\").csv(outPath + \"/csv/\" + name)\n",
    "      }\n",
    "val t1 = System.currentTimeMillis\n",
    "println(\"Elapsed time : \" + ((t1 - t0).toFloat / 1000) + \"s\")\n",
    "sparkSession.stop()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d89fa1b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "XGBoost4j-Spark - Scala",
   "language": "scala",
   "name": "XGBoost4j-Spark_scala"
  },
  "language_info": {
   "codemirror_mode": "text/x-scala",
   "file_extension": ".scala",
   "mimetype": "text/x-scala",
   "name": "scala",
   "pygments_lexer": "scala",
   "version": "2.12.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
